Dealing With PyTorch Custom Datasets.
In this article, we are going to take a look at how to work with custom PyTorch datasets.
Custom datasets!! WHY??
Because you can shape them in any way you desire!!!
It is natural that we develop our own way of creating custom datasets while working on different projects.
There are some official custom dataset examples in PyTorch, like here, but they seemed a bit obscure to a beginner (like me, back then). The topics we will discuss are as follows.
- Custom Dataset Fundamentals.
- Using Torchvision Transforms.
- Dealing with pandas (read_csv).
- Embedding Class Names into File Names.
- Using DataLoader.
1. Custom Dataset Fundamentals.
A dataset class must implement the following three methods to be usable by a DataLoader later on.
- __init__(): the initial logic happens here, like reading a CSV, assigning transforms, filtering data, etc.
- __getitem__(): returns the data and the labels.
- __len__(): returns the count of samples your dataset has.
Now, the first part is to create a dataset class:
from torch.utils.data.dataset import Dataset

class MyCustomDataset(Dataset):
    def __init__(self, ...):
        # stuff

    def __getitem__(self, index):
        # stuff
        return (img, label)

    def __len__(self):
        return count  # of how many examples (images?) you have
Here, MyCustomDataset returns two things: an image and its label. But this doesn't mean that __getitem__() is restricted to returning only those.
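To make this concrete, here is a tiny but fully runnable sketch of the same structure (the class name and data are purely illustrative), wrapping a plain Python list of (tensor, label) pairs:

import torch
from torch.utils.data.dataset import Dataset

class ListDataset(Dataset):
    def __init__(self, samples):
        # Initial logic: here we simply store the list of (tensor, label) pairs
        self.samples = samples

    def __getitem__(self, index):
        # Return one data point and its label
        return self.samples[index]

    def __len__(self):
        # Number of samples in the dataset
        return len(self.samples)

# Three random 1x28x28 "images" with labels 0, 1, 2
dataset = ListDataset([(torch.randn(1, 28, 28), i) for i in range(3)])
print(len(dataset), dataset[0][1])  # 3 0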
NOTE: __getitem__() must return a specific type for a single data point (like a tensor); otherwise, while loading the data you'll get an error like:
TypeError: batch must contain tensors, numbers, dicts or lists; found <class 'PIL.PngImagePlugin.PngImageFile'>
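For instance, a minimal sketch (the image path is hypothetical) of converting the PIL image to a tensor before returning it from __getitem__(), which avoids that error:

from PIL import Image
from torchvision import transforms

to_tensor = transforms.ToTensor()

img = Image.open('some_image.png')  # PIL image: the default collate function cannot batch this
img_as_tensor = to_tensor(img)      # FloatTensor of shape (C, H, W): batches just fine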
2. Using Torchvision Transforms:
In most of the examples we will see transforms=None in __init__(); it is there so we can apply torchvision transforms to our data/images. You can find the list of all transforms here. The most common usage of transforms looks like this:
from torch.utils.data.dataset import Dataset
from torchvision import transforms

class MyCustomDataset(Dataset):
    def __init__(self, ..., transforms=None):
        # stuff
        ...
        self.transforms = transforms

    def __getitem__(self, index):
        # stuff
        ...
        data = # Some data read from a file or image
        # If the transforms variable is not empty, the transforms are applied
        # to the data in the order in which they were composed
        if self.transforms is not None:
            data = self.transforms(data)
        return (data, label)

    def __len__(self):
        return count  # of how many data points (images?) you have

if __name__ == '__main__':
    # Define transforms (1)
    transformations = transforms.Compose([transforms.CenterCrop(100),
                                          transforms.ToTensor()])
    # Call the dataset
    custom_dataset = MyCustomDataset(..., transformations)
You can also define the transforms inside the Dataset class, like this:
from torch.utils.data.dataset import Dataset
from torchvision import transforms

class MyCustomDataset(Dataset):
    def __init__(self, ...):
        # stuff
        ...
        # (2) One way is to define the transforms individually
        self.center_crop = transforms.CenterCrop(100)
        self.to_tensor = transforms.ToTensor()
        # (3) Or you can still compose them, like
        self.transformations = transforms.Compose([transforms.CenterCrop(100),
                                                   transforms.ToTensor()])

    def __getitem__(self, index):
        # stuff
        ...
        data = # Some data read from a file or image
        # Calling a transform object invokes its __call__() and applies the transform
        data = self.center_crop(data)  # (2)
        data = self.to_tensor(data)    # (2)
        # Or you can call the composed version
        data = self.transformations(data)  # (3)
        # Note that you only need one of the implementations, (2) or (3)
        return (data, label)

    def __len__(self):
        return count  # of how many data points (images?) you have

if __name__ == '__main__':
    # Call the dataset
    custom_dataset = MyCustomDataset(...)
3. Dealing with pandas (read_csv):
Now our dataset has a file name, a label, and an extra-operation indicator; when the indicator is set, we'll perform some extra operation on that image.
+-----------+-------+-----------------+
| File Name | Label | Extra Operation |
+-----------+-------+-----------------+
| tr_0.png  | 5     | TRUE            |
| tr_1.png  | 0     | FALSE           |
| tr_2.png  | 4     | FALSE           |
+-----------+-------+-----------------+
Let's build a custom dataset that reads the image locations and labels from this CSV (the file itself has no header row, so we'll read it with header=None):
import numpy as np
import pandas as pd
from PIL import Image
from torch.utils.data.dataset import Dataset
from torchvision import transforms

class CustomDatasetFromImages(Dataset):
    def __init__(self, csv_path):
        """
        Args:
            csv_path (string): path to the csv file listing image paths,
                labels and the extra-operation flag
        """
        # Transforms
        self.to_tensor = transforms.ToTensor()
        # Read the csv file
        self.data_info = pd.read_csv(csv_path, header=None)
        # First column contains the image paths
        self.image_arr = np.asarray(self.data_info.iloc[:, 0])
        # Second column is the labels
        self.label_arr = np.asarray(self.data_info.iloc[:, 1])
        # Third column is the operation indicator
        self.operation_arr = np.asarray(self.data_info.iloc[:, 2])
        # Calculate len
        self.data_len = len(self.data_info.index)

    def __getitem__(self, index):
        # Get the image path from the pandas df
        single_image_name = self.image_arr[index]
        # Open image
        img_as_img = Image.open(single_image_name)
        # Check if there is an operation
        some_operation = self.operation_arr[index]
        # If there is an operation
        if some_operation:
            # Do some operation on the image
            # ...
            # ...
            pass
        # Transform the image to a tensor
        img_as_tensor = self.to_tensor(img_as_img)
        # Get the label of the image from the pandas column
        single_image_label = self.label_arr[index]
        return (img_as_tensor, single_image_label)

    def __len__(self):
        return self.data_len

if __name__ == "__main__":
    # Call dataset
    custom_mnist_from_images = \
        CustomDatasetFromImages('../data/mnist_labels.csv')
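As a quick sanity check, you can index the dataset directly before wrapping it in a DataLoader. A sketch, assuming the CSV above exists and points to valid grayscale images:

dataset = CustomDatasetFromImages('../data/mnist_labels.csv')

img_as_tensor, label = dataset[0]  # calls __getitem__(0)
print(img_as_tensor.shape, label)  # e.g. torch.Size([1, 28, 28]) 5 for a grayscale MNIST image
print(len(dataset))                # calls __len__()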
Another example is reading an image from a CSV where the value of every pixel is listed in the columns (e.g., MNIST). This only needs a small change of logic in __getitem__(). In the end, we'll return the images as tensors together with their labels. The data looks like this:
+-------+---------+---------+-----+
| Label | pixel_1 | pixel_2 | ... |
+-------+---------+---------+-----+
| 1     | 50      | 99      | ... |
| 0     | 21      | 223     | ... |
| 9     | 44      | 112     | ... |
+-------+---------+---------+-----+
Now, the code looks like:
class CustomDatasetFromCSV(Dataset):
    def __init__(self, csv_path, height, width, transforms=None):
        """
        Args:
            csv_path (string): path to csv file
            height (int): image height
            width (int): image width
            transforms: pytorch transforms for transforms and tensor conversion
        """
        self.data = pd.read_csv(csv_path)
        self.labels = np.asarray(self.data.iloc[:, 0])
        self.height = height
        self.width = width
        self.transforms = transforms

    def __getitem__(self, index):
        single_image_label = self.labels[index]
        # Read the 784 pixels and reshape the 1D array ([784]) into a 2D array ([height, width], here [28, 28])
        img_as_np = np.asarray(self.data.iloc[index][1:]).reshape(self.height, self.width).astype('uint8')
        # Convert the image from a numpy array to a PIL image; mode 'L' is for grayscale
        img_as_img = Image.fromarray(img_as_np)
        img_as_img = img_as_img.convert('L')
        # Transform the image to a tensor
        if self.transforms is not None:
            img_as_tensor = self.transforms(img_as_img)
        # Return the image and the label
        return (img_as_tensor, single_image_label)

    def __len__(self):
        return len(self.data.index)

if __name__ == "__main__":
    transformations = transforms.Compose([transforms.ToTensor()])
    custom_mnist_from_csv = \
        CustomDatasetFromCSV('../data/mnist_in_csv.csv', 28, 28, transformations)
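To see what the reshape inside __getitem__() does, here is a standalone sketch with made-up pixel values standing in for one CSV row:

import numpy as np
from PIL import Image

# Dummy values standing in for the 784 pixel columns of one CSV row
pixels = np.random.randint(0, 256, size=784)

img_as_np = pixels.reshape(28, 28).astype('uint8')    # 1D [784] -> 2D [28, 28]
img_as_img = Image.fromarray(img_as_np).convert('L')  # grayscale PIL image
print(img_as_img.size)                                # (28, 28)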
4. Embedding Class Names into File Names:
Here the class of each image is embedded in its file name (the digit after '_c'), so we parse the label directly from the path instead of reading it from a CSV:
import glob

import numpy as np
import torch
from PIL import Image
from torch.utils.data.dataset import Dataset

class CustomDatasetFromFile(Dataset):
    def __init__(self, folder_path):
        """
        A dataset example where the class is embedded in the file names.
        This example also does not use any torch transforms.
        Args:
            folder_path (string): path to the image folder
        """
        # Get the image list
        self.image_list = glob.glob(folder_path + '*')
        # Calculate len
        self.data_len = len(self.image_list)

    def __getitem__(self, index):
        # Get the image path from the list
        single_image_path = self.image_list[index]
        # Open image
        im_as_im = Image.open(single_image_path)
        # Do some operations on the image
        # Convert to numpy, dim = 28x28
        im_as_np = np.asarray(im_as_im) / 255
        # Add a channel dimension, dim = 1x28x28
        # Note: you do not need to do this if you are reading RGB images
        # or if there is already a channel dimension
        im_as_np = np.expand_dims(im_as_np, 0)
        # Some preprocessing operations on the numpy array
        # ...
        # ...
        # Transform the image to a tensor and change the data type
        im_as_ten = torch.from_numpy(im_as_np).float()
        # Get the label (class) of the image based on the file name
        class_indicator_location = single_image_path.rfind('_c')
        label = int(single_image_path[class_indicator_location + 2:class_indicator_location + 3])
        return (im_as_ten, label)

    def __len__(self):
        return self.data_len
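The label parsing above assumes file names such as img_12_c3.png (hypothetical), where the digit right after '_c' is the class. A quick sketch of that slicing on its own:

# Hypothetical file name where '_c3' encodes class 3
single_image_path = './mnist_images/img_12_c3.png'

class_indicator_location = single_image_path.rfind('_c')
label = int(single_image_path[class_indicator_location + 2:class_indicator_location + 3])
print(label)  # 3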
5. Using DataLoader:
PyTorch's DataLoader calls __getitem__() and wraps the returned samples into a batch. Technically, we could skip the DataLoader, call __getitem__() one sample at a time ourselves, and feed the data into the model, but the DataLoader handles batching (and optionally shuffling) for us. We can create the DataLoader like this:
...

if __name__ == "__main__":
    # Define transforms
    transformations = transforms.Compose([transforms.ToTensor()])
    # Define the custom dataset
    custom_mnist_from_csv = \
        CustomDatasetFromCSV('../data/mnist_in_csv.csv',
                             28, 28,
                             transformations)
    # Define the data loader
    mn_dataset_loader = torch.utils.data.DataLoader(dataset=custom_mnist_from_csv,
                                                    batch_size=10,
                                                    shuffle=False)
    for images, labels in mn_dataset_loader:
        # Feed the data to the model
        pass
Here, batch_size decides how many individual data points will be wrapped into a single batch. The DataLoader then returns tensors of shape (batch size, channels, height, width); for MNIST with batch_size=10, that is torch.Size([10, 1, 28, 28]).
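For example, a quick sketch that inspects the first batch (assuming the dataset above loads correctly):

for images, labels in mn_dataset_loader:
    print(images.shape)  # torch.Size([10, 1, 28, 28]) with batch_size=10 on MNIST
    print(labels.shape)  # torch.Size([10])
    break                # inspect only the first batch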
That’s it!!!
Custom Datasets!! No Worries!!